{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Alphabet(dict):\n",
    "    def __init__(self, start_id=1):\n",
    "        self.fid = start_id\n",
    "    \n",
    "    def add(self, item):\n",
    "        idx = self.get(item, None)\n",
    "        if idx is None:\n",
    "            idx = self.fid\n",
    "            self[item] = idx\n",
    "            self.fid += 1\n",
    "        return idx\n",
    "    \n",
    "    def dump(self, fname):\n",
    "        with open(fname, 'w') as out:\n",
    "            for k in sorted(self.keys()):\n",
    "                out.write(\"{}\\t{}\\n\".format(k, self[k]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def load_embed(fname):\n",
    "    f = open(fname)\n",
    "    cnt, vocab_size, embed_dim = 0, 0, 0\n",
    "    embed_dict = {}\n",
    "    print 'Load embedding file start!'\n",
    "    for line in f:\n",
    "        cnt += 1\n",
    "        if cnt % 10000 == 0:\n",
    "            print cnt\n",
    "        terms = line.strip().split(' ')\n",
    "        if len(terms) == 2:\n",
    "            vocab_size = int(terms[0])\n",
    "            embed_dim = int(terms[1])\n",
    "        if len(terms) == embed_dim + 1:\n",
    "            embed_dict[terms[0]] = np.array([float(ii) for ii in terms[1:]])\n",
    "    print 'Load embedding file finish!'\n",
    "    return embed_dict, vocab_size, embed_dim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def parse_embed(embed_file):\n",
    "    alphabet, embed_mat = Alphabet(), []\n",
    "    embed_dict, vocab_size, embed_dim = load_embed(embed_file)\n",
    "    unknown_word_idx = 0\n",
    "    embed_mat.append(np.random.uniform(-0.25, 0.25, embed_dim))\n",
    "    for word in embed_dict:\n",
    "        alphabet.add(word)\n",
    "        embed_mat.append(embed_dict[word])\n",
    "    dummy_word_idx = alphabet.fid\n",
    "    embed_mat.append(np.zeros(embed_dim))\n",
    "    return alphabet, embed_mat, unknown_word_idx, dummy_word_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Load embedding file start!\n",
      "10000\n",
      "Load embedding file finish!\n",
      "Load embedding file start!\n",
      "10000\n",
      "20000\n",
      "30000\n",
      "40000\n",
      "50000\n",
      "60000\n",
      "70000\n",
      "80000\n",
      "90000\n",
      "100000\n",
      "110000\n",
      "120000\n",
      "130000\n",
      "140000\n",
      "150000\n",
      "160000\n",
      "170000\n",
      "180000\n",
      "190000\n",
      "200000\n",
      "210000\n",
      "220000\n",
      "230000\n",
      "240000\n",
      "250000\n",
      "260000\n",
      "270000\n",
      "280000\n",
      "290000\n",
      "300000\n",
      "310000\n",
      "320000\n",
      "330000\n",
      "340000\n",
      "350000\n",
      "360000\n",
      "370000\n",
      "380000\n",
      "390000\n",
      "400000\n",
      "410000\n",
      "Load embedding file finish!\n"
     ]
    }
   ],
   "source": [
    "char_embed_file = '../ieee_zhihu_cup/char_embedding.txt'\n",
    "word_embed_file = '../ieee_zhihu_cup/word_embedding.txt'\n",
    "char_alphabet, char_embed_mat, unknown_char_idx, dummy_char_idx = parse_embed(char_embed_file)\n",
    "word_alphabet, word_embed_mat, unknown_word_idx, dummy_word_idx = parse_embed(word_embed_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def load_topic(fname):\n",
    "    f = open(fname)\n",
    "    cnt, idx, pidx, title_char, title_word, desc_char, desc_word = 0, [], [], [], [], [], []\n",
    "    print \"Load topic start!\"\n",
    "    for line in f:\n",
    "        cnt += 1\n",
    "        if cnt % 100 == 0:\n",
    "            print cnt\n",
    "        terms = line.strip().split('\\t')\n",
    "        idx.append(terms[0])\n",
    "        if len(terms) == 6:\n",
    "            pidx.append(terms[1])\n",
    "            title_char.append(terms[2])\n",
    "            title_word.append(terms[3])\n",
    "            desc_char.append(terms[4])\n",
    "            desc_word.append(terms[5])\n",
    "        elif len(terms) == 5:\n",
    "            pidx.append(terms[1])\n",
    "            title_char.append(terms[2])\n",
    "            title_word.append(terms[3])\n",
    "            desc_char.append(terms[4])\n",
    "            desc_word.append('')\n",
    "        elif len(terms) == 4:\n",
    "            pidx.append(terms[1])\n",
    "            title_char.append(terms[2])\n",
    "            title_word.append(terms[3])\n",
    "            desc_char.append('')\n",
    "            desc_word.append('')\n",
    "    print \"Load topic finish!\"\n",
    "    return idx, pidx, title_char, title_word, desc_char, desc_word"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def getTopicid(data, idx_dict):\n",
    "    data_idx = []\n",
    "    for item in data:\n",
    "        item_arr = [ii for ii in item.split(',') if ii != '']\n",
    "        ex = np.zeros(len(item_arr))\n",
    "        for i, topic in enumerate(item_arr):\n",
    "            idx = idx_dict.get(topic)\n",
    "            ex[i] = idx\n",
    "        data_idx.append(ex.astype('int32'))\n",
    "    return data_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def convert2indices(data, alphabet, unknown_word_idx, dummy_word_idx, max_length):\n",
    "    data_idx = []\n",
    "    for item in data:\n",
    "        item_arr = [ii for ii in item.split(',') if ii != '']\n",
    "        ex = np.ones(max_length) * dummy_word_idx\n",
    "        for i, word in enumerate(item_arr):\n",
    "            if i >= max_length:\n",
    "                break\n",
    "            idx = alphabet.get(word, unknown_word_idx)\n",
    "            ex[i] = idx\n",
    "        data_idx.append(ex)\n",
    "    data_idx = np.array(data_idx).astype('int32')\n",
    "    return data_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100\n",
      "200\n",
      "300\n",
      "400\n",
      "500\n",
      "600\n",
      "700\n",
      "800\n",
      "900\n",
      "1000\n",
      "1100\n",
      "1200\n",
      "1300\n",
      "1400\n",
      "1500\n",
      "1600\n",
      "1700\n",
      "1800\n",
      "1900\n"
     ]
    }
   ],
   "source": [
    "topic_file = '../ieee_zhihu_cup/topic_info.txt'\n",
    "topic_idx, topic_pidx, topic_title_char, topic_title_word, topic_desc_char, \\\n",
    "                                topic_desc_word = load_topic(topic_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "topic_idx_dict = {tid:i for i, tid in enumerate(topic_idx)}\n",
    "topic_pidx_indices = getTopicid(topic_pidx, topic_idx_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "topic_title_char_max_length = 4\n",
    "topic_title_word_max_length = 1\n",
    "topic_desc_char_max_length = 64\n",
    "topic_desc_word_max_length = 31"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "topic_title_char_indices = convert2indices(topic_title_char, char_alphabet, \\\n",
    "                                           unknown_char_idx, dummy_char_idx, \\\n",
    "                                           topic_title_char_max_length)\n",
    "topic_title_word_indices = convert2indices(topic_title_word, word_alphabet, \\\n",
    "                                           unknown_word_idx, dummy_word_idx, \\\n",
    "                                           topic_title_word_max_length)\n",
    "topic_desc_char_indices = convert2indices(topic_desc_char, char_alphabet, \\\n",
    "                                           unknown_char_idx, dummy_char_idx, \\\n",
    "                                           topic_desc_char_max_length)\n",
    "topic_desc_word_indices = convert2indices(topic_desc_word, word_alphabet, \\\n",
    "                                           unknown_word_idx, dummy_word_idx, \\\n",
    "                                           topic_desc_word_max_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "basedir = './topic'\n",
    "np.save('{}/topic_idx.npy'.format(basedir), topic_idx)\n",
    "np.save('{}/topic_pidx_indices.npy'.format(basedir), topic_pidx_indices)\n",
    "np.save('{}/topic_title_char_indices.npy'.format(basedir), topic_title_char_indices)\n",
    "np.save('{}/topic_title_word_indices.npy'.format(basedir), topic_title_word_indices)\n",
    "np.save('{}/topic_desc_char_indices.npy'.format(basedir), topic_desc_char_indices)\n",
    "np.save('{}/topic_desc_word_indices.npy'.format(basedir), topic_desc_word_indices)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
